package com.feizhaiyou.encrypt.advice;

import com.feizhaiyou.encrypt.handler.SecurityHandler;
import com.feizhaiyou.encrypt.handler.SensitiveHandler;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.core.OrderComparator;
import org.springframework.core.Ordered;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;

/**
 * Base class for advices that walk an object graph and pass every {@link String} found in it to
 * {@link #handleSecurity(String, Annotation[])}.
 *
 * <p>Traversal descends into object arrays, {@link List} elements, {@link Set} elements,
 * {@link Map} values (keys are left untouched) and the instance fields of objects whose class is
 * accepted by {@link #isStandardClass(Object)}, up to a maximum recursion depth. Containers that
 * reject modification are logged; their mutable elements are still processed in place.
 *
 * @author ls
 */
@Slf4j
public abstract class AbstractSecurityAdvice implements BeanFactoryAware, InitializingBean {

    /**
     * Default maximum recursion depth, presumably passed by subclasses as {@code maxCleanDepth}.
     * It is not initialised here, so it stays 0 unless assigned elsewhere.
     * NOTE(review): with 0, {@link #handleObject} processes nothing; confirm it is always configured.
     */
    public int DEFAULT_CLEAN_DEPTH;

    /**
     * Class-name prefixes (or full-name regular expressions) identifying classes whose instance
     * fields are traversed. See {@link #isStandardClass(Object)}.
     */
    public List<String> STANDARD_CLASS = new ArrayList<>();

    /** Bean factory injected by Spring through {@link #setBeanFactory(BeanFactory)}. */
    protected ConfigurableListableBeanFactory beanFactory;

    /** First {@link SecurityHandler} bean in priority order, resolved in {@link #afterPropertiesSet()}. */
    protected SecurityHandler securityHandler;
    /** First {@link SensitiveHandler} bean in priority order, resolved in {@link #afterPropertiesSet()}. */
    protected SensitiveHandler sensitiveHandler;


    /**
     * Encrypts, decrypts or otherwise transforms a single string.
     *
     * @param value       the string to process
     * @param annotations annotations that drive the processing (for a string read from a field,
     *                    that field's declared annotations)
     * @return the resulting string
     */
    public abstract String handleSecurity(String value, Annotation[] annotations);


    /**
     * Recursively processes {@code result}, passing every reachable string to
     * {@link #handleSecurity(String, Annotation[])}.
     *
     * <p>A string is returned in processed form. Object arrays, lists, sets and map values are
     * processed element by element. Objects of a {@linkplain #isStandardClass(Object) standard
     * class} have their non-static instance fields processed; each string field uses that field's
     * own annotations. Anything else (including primitive arrays and enum constants) is returned
     * unchanged.
     *
     * @param currentTime   current recursion depth
     * @param maxCleanDepth maximum recursion depth; once {@code currentTime} reaches it,
     *                      {@code result} is returned untouched
     * @param result        value to process, may be {@code null}
     * @param annotations   annotations used when {@code result} or its container elements are strings
     * @return the processed string if {@code result} is a string, {@code null} for {@code null}
     *         input, otherwise {@code result} itself (modified in place)
     * @throws Exception if reflective field access fails or processing of a nested value throws
     */
    protected Object handleObject(
            int currentTime, // current recursion depth
            int maxCleanDepth, // maximum recursion depth
            Object result,  // value being processed
            Annotation[] annotations
    ) throws Exception {
        if (Objects.isNull(result)) {
            return null;
        }

        if (currentTime >= maxCleanDepth) {
            log.warn("currentTime：{}，maxCleanDepth：{}", currentTime, maxCleanDepth);
            return result;
        }

        final int nextDepth = currentTime + 1;

        if (result instanceof String) {
            // Only strings can be encrypted/decrypted.
            return handleSecurity((String) result, annotations);
        }
        if (handleContainer(result, nextDepth, maxCleanDepth, annotations)) {
            return result;
        }
        // Check the object's own runtime class: isStandardClass(Object) would treat a
        // java.lang.Class value as the class it represents. Enum constants are shared singletons
        // whose fields (static constants, java.lang.Enum internals) must not be rewritten.
        if (!(result instanceof Enum) && isStandardClass(result.getClass())) {
            handleFields(result, nextDepth, maxCleanDepth);
        }
        return result;
    }

    /**
     * Processes the elements of {@code value} if it is an array, {@link List}, {@link Set} or {@link Map}.
     *
     * @return {@code true} if {@code value} was one of those container types, {@code false} otherwise
     */
    @SuppressWarnings("unchecked")
    private boolean handleContainer(Object value, int nextDepth, int maxCleanDepth,
                                    Annotation[] annotations) throws Exception {
        if (value.getClass().isArray()) {
            // Primitive arrays (int[], byte[], ...) cannot hold strings and are not Object[].
            if (value instanceof Object[]) {
                wrapperNewObjArray((Object[]) value, nextDepth, maxCleanDepth, annotations);
            }
        } else if (value instanceof List) {
            wrapperNewObjList((List<Object>) value, nextDepth, maxCleanDepth, annotations);
        } else if (value instanceof Set) {
            wrapperNewObjSet((Set<Object>) value, nextDepth, maxCleanDepth, annotations);
        } else if (value instanceof Map) {
            wrapperNewObjMap((Map<Object, Object>) value, nextDepth, maxCleanDepth, annotations);
        } else {
            return false;
        }
        return true;
    }

    /**
     * Processes every non-static, non-synthetic field (including inherited ones) of {@code target},
     * using each field's declared annotations.
     */
    private void handleFields(Object target, int nextDepth, int maxCleanDepth) throws Exception {
        for (Field field : findAllDeclaredFields(target.getClass())) {
            // Static fields are shared class state (and static finals are not writable);
            // synthetic fields (e.g. an inner class's outer reference) are not user data.
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            field.setAccessible(true);
            Object value = field.get(target);
            if (value == null) {
                continue;
            }

            Annotation[] fieldAnnotations = field.getDeclaredAnnotations();

            if (value instanceof String) {
                Object handled = handleSecurity((String) value, fieldAnnotations);
                if (handled != value) {
                    field.set(target, handled);
                }
            } else if (!handleContainer(value, nextDepth, maxCleanDepth, fieldAnnotations)) {
                Object handled = handleObject(nextDepth, maxCleanDepth, value, fieldAnnotations);
                // Nested objects are modified in place; only write back when the reference changed.
                if (handled != value) {
                    field.set(target, handled);
                }
            }
        }
    }


    /**
     * Collects the declared fields of a class and all of its superclasses.
     *
     * @param resultClass class to inspect
     * @return fields ordered from {@code resultClass} up the hierarchy, in declaration order per class
     */
    private Field[] findAllDeclaredFields(Class<?> resultClass) {
        List<Field> fields = new ArrayList<>();
        for (Class<?> current = resultClass; current != null; current = current.getSuperclass()) {
            Collections.addAll(fields, current.getDeclaredFields());
        }
        return fields.toArray(new Field[0]);
    }

    /**
     * Processes every element of an object array, storing the result back into the array.
     *
     * @param valueList     array to process in place
     * @param nextDepth     recursion depth of the elements
     * @param maxCleanDepth maximum recursion depth
     * @param annotations   annotations applied to string elements
     * @throws Exception propagated from {@link #handleObject}
     */
    private void wrapperNewObjArray(Object[] valueList, int nextDepth, int maxCleanDepth,
                                    Annotation[] annotations) throws Exception {
        for (int i = 0; i < valueList.length; i++) {
            valueList[i] = handleObject(nextDepth, maxCleanDepth, valueList[i], annotations);
        }
    }

    /**
     * Processes every element of a list. Changed elements are written back with
     * {@link List#set}; if the list is unmodifiable, that is logged once and the remaining elements
     * are still processed in place.
     *
     * @param valueList     list to process
     * @param nextDepth     recursion depth of the elements
     * @param maxCleanDepth maximum recursion depth
     * @param annotations   annotations applied to string elements
     * @throws Exception propagated from {@link #handleObject}
     */
    private void wrapperNewObjList(List<Object> valueList, int nextDepth, int maxCleanDepth,
                                   Annotation[] annotations) throws Exception {
        boolean writable = true;
        for (int i = 0; i < valueList.size(); i++) {
            Object value = valueList.get(i);
            Object handled = handleObject(nextDepth, maxCleanDepth, value, annotations);
            if (!writable || handled == value) {
                continue;
            }
            try {
                valueList.set(i, handled);
            } catch (UnsupportedOperationException e) {
                writable = false;
                // NOTE(review): this logs the whole list, which may contain plaintext sensitive values.
                log.error("value:{} class:{} is unModify!", valueList, valueList.getClass().getSimpleName());
            }
        }
    }

    /**
     * Processes every value of a map (keys are untouched). Changed values are written back through
     * {@link Map.Entry#setValue}, which avoids structural modification during iteration. If the
     * map is unmodifiable, that is logged once and the remaining values are still processed in place.
     *
     * @param objectMap     map to process
     * @param nextDepth     recursion depth of the values
     * @param maxCleanDepth maximum recursion depth
     * @param annotations   annotations applied to string values
     * @throws Exception propagated from {@link #handleObject}
     */
    private void wrapperNewObjMap(Map<Object, Object> objectMap, int nextDepth, int maxCleanDepth,
                                  Annotation[] annotations) throws Exception {
        boolean writable = true;
        for (Map.Entry<Object, Object> entry : objectMap.entrySet()) {
            Object value = entry.getValue();
            Object handled = handleObject(nextDepth, maxCleanDepth, value, annotations);
            if (!writable || handled == value) {
                continue;
            }
            try {
                entry.setValue(handled);
            } catch (UnsupportedOperationException e) {
                writable = false;
                // NOTE(review): this logs the whole map, which may contain plaintext sensitive values.
                log.error("value:{} class:{} is unModify!", objectMap, objectMap.getClass().getSimpleName());
            }
        }
    }

    /**
     * Processes every element of a set and rebuilds the set from the results, so hash- or
     * order-based sets see the new element values. If the set is unmodifiable, that is logged and
     * the set is left as is (its elements have already been processed in place).
     *
     * @param objectSet     set to process
     * @param nextDepth     recursion depth of the elements
     * @param maxCleanDepth maximum recursion depth
     * @param annotations   annotations applied to string elements
     * @throws Exception propagated from {@link #handleObject}
     */
    private void wrapperNewObjSet(Set<Object> objectSet, int nextDepth, int maxCleanDepth,
                                  Annotation[] annotations) throws Exception {
        List<Object> handledValues = new ArrayList<>(objectSet.size());
        for (Object obj : objectSet) {
            handledValues.add(handleObject(nextDepth, maxCleanDepth, obj, annotations));
        }
        try {
            objectSet.clear();
        } catch (UnsupportedOperationException e) {
            // NOTE(review): this logs the whole set, which may contain plaintext sensitive values.
            log.error("value:{} class:{} is unModify!", objectSet, objectSet.getClass().getSimpleName());
            // addAll would throw as well on an unmodifiable set.
            return;
        }
        objectSet.addAll(handledValues);
    }

    /**
     * Tells whether a class is a "standard" class whose instance fields should be traversed.
     *
     * @param result a {@link Class}, or an object whose runtime class is checked
     * @return {@code true} if the class name starts with, or fully matches as a regular
     *         expression, any entry of {@link #STANDARD_CLASS}
     */
    boolean isStandardClass(Object result) {
        Class<?> clazz = result instanceof Class ? (Class<?>) result : result.getClass();
        String className = clazz.getName();
        for (String standardClass : STANDARD_CLASS) {
            if (className.startsWith(standardClass) || className.matches(standardClass)) {
                return true;
            }
        }
        return false;
    }


    /**
     * Stores the bean factory; it must be a {@link ConfigurableListableBeanFactory}.
     *
     * @throws IllegalArgumentException if {@code beanFactory} is not a {@link ConfigurableListableBeanFactory}
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        if (!(beanFactory instanceof ConfigurableListableBeanFactory)) {
            throw new IllegalArgumentException(
                    "AbstractSecurityAdvice requires a ConfigurableListableBeanFactory: " + beanFactory);
        }
        this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
    }

    /**
     * Sorts {@code list} in place, using the factory's dependency comparator when it is a
     * {@link DefaultListableBeanFactory} that has one, otherwise {@link OrderComparator#INSTANCE}.
     */
    private static void sortProcessors(List<?> list, ConfigurableListableBeanFactory beanFactory) {
        Comparator<Object> comparatorToUse = null;
        if (beanFactory instanceof DefaultListableBeanFactory) {
            comparatorToUse = ((DefaultListableBeanFactory) beanFactory).getDependencyComparator();
        }
        if (comparatorToUse == null) {
            comparatorToUse = OrderComparator.INSTANCE;
        }
        list.sort(comparatorToUse);
    }

    /**
     * Resolves the highest-priority {@link SecurityHandler} and {@link SensitiveHandler} beans.
     *
     * @throws IllegalStateException if no bean of either type is registered
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        this.securityHandler = firstBean(getBeanByType(SecurityHandler.class), SecurityHandler.class);
        this.sensitiveHandler = firstBean(getBeanByType(SensitiveHandler.class), SensitiveHandler.class);
    }

    /** Returns the first bean of {@code beans}, failing with a descriptive message when there is none. */
    private static <T> T firstBean(List<T> beans, Class<T> type) {
        if (beans.isEmpty()) {
            throw new IllegalStateException("No bean of type " + type.getName() + " found in the BeanFactory");
        }
        return beans.get(0);
    }

    /**
     * Returns all beans of the given type. Beans implementing {@link Ordered} come first, sorted
     * by {@link #sortProcessors}; the remaining beans follow in bean-name order as returned by the factory.
     *
     * @param tClass bean type to look up
     * @param <T>    bean type
     * @return the matching beans, possibly empty
     */
    private <T> List<T> getBeanByType(Class<T> tClass) {
        String[] beanNames = beanFactory.getBeanNamesForType(tClass, true, false);
        List<String> orderedNames = new ArrayList<>();
        List<String> nonOrderedNames = new ArrayList<>();
        for (String beanName : beanNames) {
            if (beanFactory.isTypeMatch(beanName, Ordered.class)) {
                orderedNames.add(beanName);
            } else {
                nonOrderedNames.add(beanName);
            }
        }

        List<T> orderedBeans = new ArrayList<>(orderedNames.size());
        for (String oName : orderedNames) {
            orderedBeans.add(beanFactory.getBean(oName, tClass));
        }
        sortProcessors(orderedBeans, beanFactory);

        List<T> list = new ArrayList<>(beanNames.length);
        list.addAll(orderedBeans);
        for (String nName : nonOrderedNames) {
            list.add(beanFactory.getBean(nName, tClass));
        }
        return list;
    }

}
